# medmnist_datamodule.py
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Union

import medmnist
import torch
from hydra.utils import call
from medmnist import INFO
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms as transform_lib

from src.data.data_utils import (
    add_attrs,
    dirichlet_split,
    pathological_client_data_split,
    split_dataset_train_val,
    split_subsets_train_val,
)


class MedMNISTDataModule(LightningDataModule):
    """Lightning data module for a MedMNIST dataset split across several clients.

    The training split is partitioned into per-client subsets by ``split_function``.
    Clients are visited one at a time: ``train_dataloader`` serves the subset of the
    client at ``current_client_idx`` and ``next_client`` advances that index.

    Args:
        data_flag: MedMNIST dataset key used to look up ``medmnist.INFO``.
        split_function: Object passed to ``hydra.utils.call`` together with
            ``dataset=<train dataset>``. The result is indexed by client index and
            iterated in ``setup``.
        data_dir: Root directory passed to the MedMNIST dataset class.
        val_split: Stored but not read by this class.
        num_workers: ``DataLoader`` worker count.
        normalize: If true, append ``Normalize(mean=0.5, std=0.5)`` per channel.
        seed: Stored but not read by this class.
        batch_size: Batch size for every loader.
        num_clients: Number of clients; also the number of validation subsets.
        fair_val: Stored but not read by this class.
    """

    def __init__(
            self,
            data_flag: str,
            split_function,
            data_dir: Union[str, Path] = Path("/tmp"),
            val_split: float = 0.1,
            num_workers: int = 16,
            normalize: bool = False,
            seed: int = 42,
            batch_size: int = 32,
            num_clients: int = 3,
            fair_val: bool = False,
            *args,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.data_flag = data_flag
        self.data_dir = data_dir
        self.val_split = val_split
        self.num_workers = num_workers
        self.normalize = normalize
        self.seed = seed
        self.batch_size = batch_size
        self.num_clients = num_clients
        self.fair_val = fair_val
        self.split_function = split_function

        self.info = INFO[data_flag]
        self.task = self.info['task']
        self.n_channels = self.info['n_channels']
        self.num_classes = len(self.info['label'])
        self.DataClass = getattr(medmnist, self.info['python_class'])

        # Per-channel statistics used only when ``normalize`` is true.
        self.ds_mean = (0.5,) * self.n_channels
        self.ds_std = (0.5,) * self.n_channels

        # Populated by ``setup`` / ``transfer_setup``.
        self.datasets_train: Optional[List[Subset]] = None
        self.datasets_val: Optional[List[Subset]] = None
        self.train_dataset: Optional[Dataset] = None
        self.val_dataset: Optional[Dataset] = None
        self.test_dataset: Optional[Dataset] = None

        self.current_client_idx = 0

    def prepare_data(self):
        """Fetch the train/val/test splits into ``data_dir`` (``download=True``)."""
        for split in ('train', 'val', 'test'):
            self.DataClass(split=split, download=True, root=self.data_dir)

    def setup(self, stage: Optional[str] = None):
        """Build the per-client training and validation subsets.

        Only acts when ``stage`` is ``"fit"`` or ``None``. Training data uses
        ``aug_transforms``; validation data uses ``default_transforms``.
        """
        if stage not in ("fit", None):
            return

        self.train_dataset = self._attach_targets(self._load_split('train', self.aug_transforms))
        self.val_dataset = self._attach_targets(self._load_split('val', self.default_transforms))

        self.datasets_train = call(self.split_function, dataset=self.train_dataset)
        # Every client gets a Subset spanning the *entire* validation split.
        self.datasets_val = [
            Subset(self.val_dataset, torch.arange(len(self.val_dataset)))
            for _ in range(self.num_clients)
        ]

        add_attrs(self.datasets_train, self.datasets_val)

        for idx, subset in enumerate(self.datasets_train):
            print(f"Client {idx} training subset size: {len(subset)}")

        print(f"Clients' validation subset size: {len(self.val_dataset)}")

    def transfer_setup(self):
        """Load the full train/val/test splits without client partitioning.

        Like ``setup`` and ``test_dataloader``, ``.targets`` is attached to every
        dataset that exposes ``labels``.
        """
        self.train_dataset = self._attach_targets(self._load_split('train', self.aug_transforms))
        self.val_dataset = self._attach_targets(self._load_split('val', self.default_transforms))
        self.test_dataset = self._attach_targets(self._load_split('test', self.default_transforms))

    def next_client(self):
        """Advance ``current_client_idx`` to the next client.

        Raises:
            AssertionError: If no further client exists. The index is left
                unchanged in that case.
        """
        next_idx = self.current_client_idx + 1
        if next_idx >= self.num_clients:
            # Validate *before* mutating so a failed call cannot leave the module
            # pointing at a non-existent client. An explicit raise (rather than
            # ``assert``) survives ``python -O``; AssertionError is kept so existing
            # callers see the same exception type.
            raise AssertionError("Client number shouldn't exceed selected number of clients")
        self.current_client_idx = next_idx

    def train_dataloader(self):
        """Return a shuffled loader over the current client's training subset.

        The last incomplete batch is dropped (``drop_last=True``).
        """
        return DataLoader(
            self.datasets_train[self.current_client_idx],
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the full validation split.

        This uses ``val_dataset`` directly, not the per-client ``datasets_val``.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
        )

    def test_dataloader(self):
        """Return an unshuffled loader over a freshly loaded test split."""
        dataset = self._attach_targets(self._load_split('test', self.default_transforms))
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
        )

    def _load_split(self, split: str, transform):
        """Instantiate MedMNIST ``split`` from ``data_dir`` without downloading."""
        return self.DataClass(split=split, transform=transform, download=False, root=self.data_dir)

    @staticmethod
    def _attach_targets(dataset):
        """Expose ``dataset.labels`` as a flat ``torch.long`` tensor in ``dataset.targets``.

        Datasets without a ``labels`` attribute are returned untouched. The dataset
        is modified in place and also returned for convenience.
        """
        if hasattr(dataset, 'labels'):
            # NOTE(review): if ``labels`` has several columns (multi-label tasks),
            # flatten() yields N*C entries that no longer align with sample
            # indices. Confirm the split functions are only used with single-label
            # tasks.
            dataset.targets = torch.as_tensor(dataset.labels.flatten(), dtype=torch.long)
        return dataset

    def _compose(self, ops):
        """Compose ``ops``, appending per-channel normalization when enabled."""
        if self.normalize:
            ops.append(transform_lib.Normalize(mean=self.ds_mean, std=self.ds_std))
        return transform_lib.Compose(ops)

    @property
    def default_transforms(self):
        """Evaluation transform: ``ToTensor`` plus optional normalization."""
        return self._compose([transform_lib.ToTensor()])

    @property
    def aug_transforms(self):
        """Training transform: random rotation (±15°), random horizontal flip,
        ``ToTensor``, plus optional normalization."""
        return self._compose([
            transform_lib.RandomRotation(degrees=15),
            transform_lib.RandomHorizontalFlip(),
            transform_lib.ToTensor(),
        ])


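# # Example usage for PathMNIST.
# # `split_function` is forwarded to `hydra.utils.call`, so it must be a Hydra config
# # (a dict/DictConfig with a `_target_` key), not a bare function object. Add any
# # extra arguments the split function needs as further keys of that config.
# pathmnist_dm = MedMNISTDataModule(
#     data_flag='pathmnist',
#     split_function={'_target_': 'src.data.data_utils.pathological_client_data_split'},
# )
# pathmnist_dm.prepare_data()
# pathmnist_dm.setup()
#
# # Example usage for OrganMNIST (axial view). 'organmnist_axial' is the MedMNIST v1
# # flag; MedMNIST v2 names this dataset 'organamnist'.
# organmnist_axial_dm = MedMNISTDataModule(
#     data_flag='organmnist_axial',
#     split_function={'_target_': 'src.data.data_utils.pathological_client_data_split'},
# )
# organmnist_axial_dm.prepare_data()
# organmnist_axial_dm.setup()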
